import matplotlib.pyplot as plt
import pandas as pd

# Global plot defaults: SimHei as the font family (presumably chosen so the
# Chinese labels used in this script render) and a base font size of 15.
plt.rcParams.update({'font.family': 'SimHei', 'font.size': 15})

# Load the training set once at import time; png() reads this module-level frame.
data = pd.read_csv('../data/train.csv')


# data.info()

def png(name, x_label, title):
    """
    Draw a combined bar/line chart of attrition per category, save it, and show it.

    Rows of the module-level ``data`` frame are grouped by ``name``. For each
    group the sum of the ``Attrition`` column is drawn as bars on the left
    Y axis and the mean of ``Attrition`` is drawn as a line on a twin right
    Y axis. The figure is written to ``../data/img/{title}.png`` (the
    directory is created when missing), displayed, and then closed.

    :param name: name of the column in ``data`` to group by
    :param x_label: label for the X axis
    :param title: chart title; also used as the output file name
    :return: None
    """
    import os

    # Create the figure and the primary (left) axes
    fig, ax1 = plt.subplots(figsize=(30, 15))

    # Group once and derive both series from the same grouping
    grouped = data.groupby(name)['Attrition']
    counts = grouped.sum()
    rates = grouped.mean()

    # Bars: per-group Attrition sum on the left axis
    ax1.bar(counts.index, counts.values, color='skyblue', label='人数')
    # The X label must be set on ax1: twinx() hides the x-axis of the twin
    # axes, so a label set on ax2 would never be drawn.
    ax1.set_xlabel(x_label)
    ax1.set_ylabel('人数')
    ax1.tick_params(axis='y')
    # Restrict the left axis ticks to integers
    ax1.yaxis.set_major_locator(plt.MaxNLocator(integer=True))

    # Line: per-group Attrition mean on a secondary axis sharing the same X axis
    ax2 = ax1.twinx()
    ax2.plot(rates.index, rates.values, color='red', linewidth=3, marker='o', label='比例')
    ax2.set_ylabel('离职比率')
    ax2.tick_params(axis='y')

    plt.title(title)

    # Merge the legend entries of both axes into a single legend
    handles1, labels1 = ax1.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(handles1 + handles2, labels1 + labels2, loc='upper left')

    plt.tight_layout()

    # savefig does not create missing directories, so make sure the target exists
    out_dir = '../data/img'
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f'{out_dir}/{title}.png')

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)


if __name__ == '__main__':
    # One entry per chart: (column to group by, X-axis label, chart title).
    charts = (
        ('BusinessTravel', '商务旅行情况', '商务旅行情况与离职比例&人数'),
        ('Department', '部门', '部门与离职比例&人数'),
        ('EducationField', '教育领域', '教育领域与离职比例&人数'),
        ('Gender', '性别', '性别与离职比例&人数'),
        ('JobRole', '工作角色', '工作角色与离职比例&人数'),
        ('MaritalStatus', '婚姻状况', '婚姻状况与离职比例&人数'),
        ('OverTime', '加班', '加班与离职比例&人数'),
    )
    # Render the charts in the listed order.
    for column, axis_label, chart_title in charts:
        png(column, axis_label, chart_title)
